import os
import random
import shutil
import time
import timeit

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable


class AvgrageMeter(object):
    """Accumulate a weighted running sum, count and mean of a metric.

    Attributes:
        avg: current mean ``sum / cnt`` (0 until the first ``update``).
        sum: weighted sum of every value passed to ``update``.
        cnt: total weight ``n`` accumulated so far.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every accumulated statistic."""
        self.avg, self.sum, self.cnt = 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean."""
        weighted = val * n
        self.sum += weighted
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Compute top-k precision, as a percentage of the batch, for each k in ``topk``.

    Args:
        output: scores with classes along dim 1 (``topk`` is taken over dim 1);
            presumably shaped (batch, num_classes).
        target: ground-truth class indices, one per sample (``target.size(0)``
            is used as the batch size).
        topk: the k values to evaluate.

    Returns:
        list of 0-dim tensors, one per k in ``topk``, each in [0, 100].
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Top-maxk predicted class indices, transposed to (maxk, batch) so that
    # row r holds every sample's r-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: the transpose above leaves `correct` non-contiguous,
        # and view(-1) raises on a non-contiguous slice when k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class Cutout(object):
    """Zero one random square patch of an image tensor, in place (Cutout augmentation).

    The patch centre is a uniformly random pixel; the patch is clipped at the
    image borders, so patches near an edge are smaller than ``length``.
    Assumes ``img`` is laid out as (C, H, W), e.g. the output of ``ToTensor``.
    """

    def __init__(self, length):
        # Nominal side length of the zeroed square, in pixels.
        self.length = length

    def __call__(self, img):
        height = img.size(1)
        width = img.size(2)
        half = self.length // 2

        # Draw the centre row, then the centre column.
        centre_y = np.random.randint(height)
        centre_x = np.random.randint(width)

        top, bottom = np.clip([centre_y - half, centre_y + half], 0, height)
        left, right = np.clip([centre_x - half, centre_x + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.

        # Same 2-D mask applied to every channel; multiplication is in place.
        img *= torch.from_numpy(keep).expand_as(img)
        return img


def _data_transforms_cifar10(args):
    """Build the (train, valid) CIFAR-10 torchvision transform pipelines.

    Train: 32x32 random crop with 4-pixel padding, random horizontal flip,
    ToTensor, then ``Cutout(args.cutout_length)`` if ``args.cutout`` is truthy,
    and finally per-channel normalization.
    Valid: ToTensor followed by the same normalization.

    Returns:
        (train_transform, valid_transform) as ``transforms.Compose`` objects.
    """
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    if args.cutout:
        # Cutout operates on tensors, so it must run after ToTensor.
        train_steps.append(Cutout(args.cutout_length))
    train_steps.append(transforms.Normalize(CIFAR_MEAN, CIFAR_STD))
    train_transform = transforms.Compose(train_steps)

    valid_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD)]
    )
    return train_transform, valid_transform


def count_parameters_in_MB(model):
    """Return the number of trainable parameters, in millions, excluding auxiliary heads.

    A parameter is excluded when its qualified name (from ``named_parameters``)
    contains the substring "auxiliary", or when it has ``requires_grad=False``.

    Counting the wanted parameters directly avoids two problems of the previous
    "trainable total minus auxiliary total" form: frozen auxiliary parameters
    were subtracted from a total that never included them (possibly giving a
    negative count), and ``np.sum`` over a generator is deprecated in NumPy.
    """
    n_params = sum(
        p.numel()
        for name, p in model.named_parameters()
        if p.requires_grad and "auxiliary" not in name
    )
    return n_params / 1e6


def save_checkpoint(state, is_best, save):
    """Write ``state`` to ``<save>/checkpoint.pth.tar`` via ``torch.save``.

    When ``is_best`` is truthy, the freshly written file is also copied to
    ``<save>/model_best.pth.tar``.
    """
    latest_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, latest_path)
    if not is_best:
        return
    shutil.copyfile(latest_path, os.path.join(save, 'model_best.pth.tar'))

# Saves only the model's weights (its state_dict), not its class/architecture;
# the same architecture must be rebuilt before the weights can be loaded back.
def save(model, model_path):
    """Serialize ``model.state_dict()`` to ``model_path`` with ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, model_path)

# Loads a state_dict (weights only) into an already-constructed model whose
# architecture must match the one that was saved.
def load(model, model_path, map_location='cpu'):
    """Load a state_dict from ``model_path`` into ``model`` in place.

    Args:
        model: module whose architecture matches the saved weights.
        model_path: file previously written by ``torch.save(model.state_dict(), ...)``.
        map_location: forwarded to ``torch.load``. The default, "cpu", lets
            weights saved from GPU tensors be read on a machine without CUDA.
            ``load_state_dict`` then copies the values into the model's existing
            parameters on whatever device they live on. Pass ``None`` to restore
            tensors onto the device they were saved from.
    """
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)


def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` in place (drop-path / stochastic depth).

    Each slice ``x[i]`` along dim 0 is zeroed with probability ``drop_prob``.
    Surviving slices are scaled by ``1 / (1 - drop_prob)`` so the expected value
    is unchanged.

    The mask is created on ``x``'s device with ``x``'s dtype, so this works on CPU
    and GPU for any input rank. The previous version hard-coded
    ``torch.cuda.FloatTensor`` and a rank-4 mask.

    Args:
        x: tensor whose dim 0 indexes samples; modified in place.
        drop_prob: probability of dropping each sample. ``<= 0`` is a no-op;
            ``>= 1`` zeros ``x``, avoiding a 0/0 NaN from dividing by keep_prob.

    Returns:
        ``x`` itself.
    """
    if drop_prob > 0.:
        if drop_prob >= 1.:
            return x.zero_()
        keep_prob = 1. - drop_prob
        # One Bernoulli draw per sample, broadcast across all remaining dims.
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def create_exp_dir(path, scripts_to_save=None):
    """Create the experiment directory and optionally snapshot scripts into it.

    Args:
        path: experiment directory. Missing parent directories are created, and
            an already-existing directory is reused.
        scripts_to_save: optional iterable of file paths. Each is copied into
            ``<path>/scripts/`` under its basename, overwriting any earlier copy.

    ``os.makedirs(..., exist_ok=True)`` replaces ``os.mkdir`` so that nested
    paths work and re-running into an existing experiment directory does not
    fail because ``scripts/`` already exists.
    """
    print('Experiment dir : {}'.format(path))
    os.makedirs(path, exist_ok=True)

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        os.makedirs(scripts_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(scripts_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)


# Gets image shape and color channels from first image in dataset
def get_image_dim(DIRECTORY, EXT):
    """Return ``(im.size, n_bands)`` for the first ``*.EXT`` image in ``DIRECTORY``.

    ``im.size`` is PIL's image size tuple. ``n_bands`` is ``len(im.getbands())``,
    i.e. the number of colour channels. Names starting with "." are skipped.
    "First" means first in ``os.scandir`` order, which is arbitrary.

    Both the directory iterator and the image file are closed before returning.

    Raises:
        IndexError: if no matching file exists. This is the same exception type
            as before, now with a descriptive message.
    """
    suffix = "." + EXT
    with os.scandir(DIRECTORY) as entries:
        first = next(
            (e.path for e in entries
             if e.name.endswith(suffix) and not e.name.startswith(".")),
            None,
        )
    if first is None:
        raise IndexError("no '{}' image found in {}".format(suffix, DIRECTORY))

    # Open one image to read its dimensions; the context manager closes the file.
    with Image.open(first) as im:
        return im.size, len(im.getbands())

def assign_gpu(how='random'):
    """Choose a CUDA device index.

    Args:
        how: ``'random'`` picks uniformly among the visible devices.
            ``'by_utilization'`` queries ``torch.cuda.utilization`` for every
            device, prints each reading, and picks the lowest (the first one on
            ties).

    Returns:
        int device index.

    Raises:
        NotImplementedError: for any other ``how``.
        RuntimeError: if no CUDA device is visible.

    The previous ``if``/``if``/``else`` chain sent ``'random'`` into the final
    ``else``, so the default strategy always raised.
    """
    if how not in ('random', 'by_utilization'):
        raise NotImplementedError(
            "unknown GPU assignment strategy {!r}; expected 'random' or "
            "'by_utilization'.".format(how))

    n_devices = torch.cuda.device_count()
    if n_devices == 0:
        raise RuntimeError("no CUDA devices available to assign.")

    if how == 'random':
        return random.randint(0, n_devices - 1)

    gpu = 0
    min_utilization = None
    for device in range(n_devices):
        utilization = torch.cuda.utilization(device)
        print("device: {}, utilization: {}".format(device, utilization))
        if min_utilization is None or utilization < min_utilization:
            min_utilization = utilization
            gpu = device
    return gpu

def num_gpus():
    """Return the number of CUDA devices visible to PyTorch."""
    visible = torch.cuda.device_count()
    return visible

def gpu_memory_usage(device):
    """Print utilization and memory diagnostics for CUDA ``device``.

    Prints, in order: ``torch.cuda.utilization``, ``torch.cuda.mem_get_info``,
    ``torch.cuda.memory_summary`` and ``torch.cuda.memory_usage`` for ``device``.
    It then prints the wall-clock seconds those queries took.
    """
    # time.perf_counter() reads a clock. The previous timeit.timeit() call instead
    # benchmarked a no-op statement 1,000,000 times, so the printed difference
    # was meaningless and each call wasted time.
    start = time.perf_counter()
    print(torch.cuda.utilization(device))
    print(torch.cuda.mem_get_info(device))
    print(torch.cuda.memory_summary(device))
    print(torch.cuda.memory_usage(device))
    print(time.perf_counter() - start)
